{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.0.0-beta1\n",
      "GPU Available:  False\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from pathlib import Path\n",
    "import glob\n",
    "import pickle\n",
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "from metrics import exact_match_metric\n",
    "from callbacks import NValidationSetsCallback, GradientLogger\n",
    "from generator import DataGenerator\n",
    "import time\n",
    "\n",
    "import sys\n",
    "sys.path.append('../../')\n",
    "from src.models.utils import concatenate_texts\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.optimizers import Adam\n",
    "print(tf.__version__)\n",
    "print(\"GPU Available: \", tf.test.is_gpu_available())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "settings_path = Path('../../settings/settings.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(str(settings_path), 'r') as file:\n",
    "    settings_dict = json.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'math_module': 'arithmetic__add_sub',\n",
       " 'train_level': 'easy',\n",
       " 'batch_size': 1024,\n",
       " 'thinking_steps': 0,\n",
       " 'epochs': 1,\n",
       " 'latent_dim': 256,\n",
       " 'save_path': '/artifacts/',\n",
       " 'data_path': '../../data/raw/v1.0/'}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "settings_dict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data\n",
    "\n",
    "Start with batching a single file before tackling the whole dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m\u001b[36mextrapolate\u001b[m\u001b[m  \u001b[1m\u001b[36minterpolate\u001b[m\u001b[m  \u001b[1m\u001b[36mtrain-easy\u001b[m\u001b[m   \u001b[1m\u001b[36mtrain-hard\u001b[m\u001b[m   \u001b[1m\u001b[36mtrain-medium\u001b[m\u001b[m\n"
     ]
    }
   ],
   "source": [
    "raw_path = Path(settings_dict['data_path'])\n",
    "!ls {raw_path}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "algebra__linear_1d.txt\n",
      "algebra__linear_1d_composed.txt\n",
      "algebra__linear_2d.txt\n",
      "algebra__linear_2d_composed.txt\n",
      "algebra__polynomial_roots.txt\n"
     ]
    }
   ],
   "source": [
    "interpolate_path = raw_path/'interpolate'\n",
    "!ls {interpolate_path} | head -5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "algebra__polynomial_roots_big.txt\n",
      "arithmetic__add_or_sub_big.txt\n",
      "arithmetic__add_sub_multiple_longer.txt\n",
      "arithmetic__div_big.txt\n",
      "arithmetic__mixed_longer.txt\n"
     ]
    }
   ],
   "source": [
    "extrapolate_path = raw_path/'extrapolate'\n",
    "!ls {extrapolate_path} | head -5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "algebra__linear_1d.txt\n",
      "algebra__linear_1d_composed.txt\n",
      "algebra__linear_2d.txt\n",
      "algebra__linear_2d_composed.txt\n",
      "algebra__polynomial_roots.txt\n"
     ]
    }
   ],
   "source": [
    "train_easy_path = raw_path/'train-easy/'\n",
    "!ls {train_easy_path} | head -5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "math_module = settings_dict[\"math_module\"]\n",
    "train_level = settings_dict[\"train_level\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets = {\n",
    "    'train':(raw_path, 'train-' + train_level + '/' + math_module),\n",
    "    'interpolate':(interpolate_path, math_module),\n",
    "    'extrapolate':(extrapolate_path, math_module)\n",
    "           }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [],
   "source": [
    "def concatenate_texts(path, pattern):\n",
    "    file_paths = list(path.glob('{}*.txt'.format(pattern)))\n",
    "    \n",
    "    input_texts = []\n",
    "    target_texts = []\n",
    "\n",
    "    for file_path in file_paths:\n",
    "        with open(str(file_path), 'r', encoding='utf-8') as f:\n",
    "            lines = f.read().split('\\n')[:-1]\n",
    "\n",
    "        input_texts.extend(['\\t' + input_text + '\\n' for input_text in lines[0::2]])\n",
    "        target_texts.extend(['\\t' + target_text + '\\n' for target_text in lines[1::2]])\n",
    "        \n",
    "    return input_texts, target_texts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Length of set train is 666666\n",
      "Length of set interpolate is 10000\n",
      "Length of set extrapolate is 10000\n",
      "CPU times: user 271 ms, sys: 79.1 ms, total: 350 ms\n",
      "Wall time: 350 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "input_texts = {}\n",
    "target_texts = {}\n",
    "\n",
    "for k, v in datasets.items():\n",
    "    input_texts[k], target_texts[k] = concatenate_texts(v[0], v[1])\n",
    "    print('Length of set {} is {}'.format(k, len(input_texts[k])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Sample:**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INPUT: What is -2 - (1 - (3 + (5 - 3)))?\n",
      "OUTPUT: 2\n"
     ]
    }
   ],
   "source": [
    "# randint's upper bound is exclusive, so start at 0 to make the first sample reachable\n",
    "random_idx = np.random.randint(0, len(input_texts['train']))\n",
    "print('INPUT:', input_texts['train'][random_idx])\n",
    "print('OUTPUT:', target_texts['train'][random_idx].strip())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Concatenate texts to get text metrics (max length, number of unique tokens, etc.):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_input_texts = sum(input_texts.values(), [])\n",
    "all_target_texts = sum(target_texts.values(), [])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_characters = set(''.join(all_input_texts))\n",
    "target_characters = set(''.join(all_target_texts))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of samples: 686666\n",
      "Number of unique input tokens: 32\n",
      "Number of unique output tokens: 13\n",
      "Max sequence length for inputs: 99\n",
      "Max sequence length for outputs: 6\n"
     ]
    }
   ],
   "source": [
    "input_characters = sorted(list(input_characters))\n",
    "target_characters = sorted(list(target_characters))\n",
    "num_encoder_tokens = len(input_characters)\n",
    "num_decoder_tokens = len(target_characters)\n",
    "max_encoder_seq_length = max([len(txt) for txt in all_input_texts])\n",
    "max_decoder_seq_length = max([len(txt) for txt in all_target_texts])\n",
    "\n",
    "print('Number of samples:', len(all_input_texts))\n",
    "print('Number of unique input tokens:', num_encoder_tokens)\n",
    "print('Number of unique output tokens:', num_decoder_tokens)\n",
    "print('Max sequence length for inputs:', max_encoder_seq_length)\n",
    "print('Max sequence length for outputs:', max_decoder_seq_length)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Evaluate 37 + -41 - (267 + 0 + 0 + 5 + -20 + (46 - 35) - (-1 + 8)).'"
      ]
     },
     "execution_count": 76,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_input_texts[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\\t-260\\n'"
      ]
     },
     "execution_count": 77,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_target_texts[-1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Process text"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Vectorise the text\n",
    "Before training, we need to map strings to a numerical representation. Create two lookup tables: one mapping question characters to numbers, and another mapping answer characters to numbers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lookup tables from each unique character to its position in the sorted vocabulary\n",
    "input_token_index = {char: position for position, char in enumerate(input_characters)}\n",
    "target_token_index = {char: position for position, char in enumerate(target_characters)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [],
   "source": [
    "def max_length(tensor):\n",
    "    return max(len(t) for t in tensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize(lang):\n",
    "    \"\"\"Fit a character-level Keras tokenizer on ``lang`` and encode it.\n",
    "\n",
    "    Returns the post-padded integer matrix for ``lang`` together with the\n",
    "    fitted tokenizer.\n",
    "    \"\"\"\n",
    "    preprocessing = tf.keras.preprocessing\n",
    "\n",
    "    # filters='' means no character is stripped before tokenizing\n",
    "    lang_tokenizer = preprocessing.text.Tokenizer(filters='', char_level=True)\n",
    "    lang_tokenizer.fit_on_texts(lang)\n",
    "\n",
    "    sequences = lang_tokenizer.texts_to_sequences(lang)\n",
    "    padded = preprocessing.sequence.pad_sequences(sequences, padding='post')\n",
    "\n",
    "    return padded, lang_tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_tensor, inp_lang_tokenizer = tokenize(all_input_texts)\n",
    "target_tensor, targ_lang_tokenizer = tokenize(all_target_texts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Longest row of the padded target tensor and of the padded input tensor\n",
    "max_length_targ, max_length_inp = max_length(target_tensor), max_length(input_tensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(6, 99)"
      ]
     },
     "execution_count": 82,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "max_length_targ, max_length_inp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "549332 549332 137334 137334\n"
     ]
    }
   ],
   "source": [
    "# Creating training and validation sets using an 80-20 split.\n",
    "# A fixed random_state makes the split reproducible across kernel restarts.\n",
    "# NOTE(review): input_tensor is built from all_input_texts, which also contains the\n",
    "# interpolate and extrapolate sets, so those samples end up in this split too - confirm\n",
    "# that is intended.\n",
    "input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(\n",
    "    input_tensor, target_tensor, test_size=0.2, random_state=42)\n",
    "\n",
    "# Show length\n",
    "print(len(input_tensor_train), len(target_tensor_train), len(input_tensor_val), len(target_tensor_val))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert(lang, tensor):\n",
    "    for t in tensor:\n",
    "        if t!=0:\n",
    "            print (\"%d ----> %s\" % (t, lang.index_word[t]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input Language; index to word mapping\n",
      "16 ----> 0\n",
      "1 ---->  \n",
      "3 ----> +\n",
      "1 ---->  \n",
      "6 ----> (\n",
      "4 ----> 1\n",
      "4 ----> 1\n",
      "1 ---->  \n",
      "2 ----> -\n",
      "1 ---->  \n",
      "15 ----> 4\n",
      "7 ----> )\n",
      "1 ---->  \n",
      "3 ----> +\n",
      "1 ---->  \n",
      "2 ----> -\n",
      "26 ----> 7\n",
      "1 ---->  \n",
      "3 ----> +\n",
      "1 ---->  \n",
      "12 ----> 3\n",
      "\n",
      "Target Language; index to word mapping\n",
      "1 ----> \t\n",
      "6 ----> 3\n",
      "2 ----> \n",
      "\n"
     ]
    }
   ],
   "source": [
    "print (\"Input Language; index to word mapping\")\n",
    "convert(inp_lang_tokenizer, input_tensor_train[0])\n",
    "print ()\n",
    "print (\"Target Language; index to word mapping\")\n",
    "convert(targ_lang_tokenizer, target_tensor_train[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shuffle buffer spans the whole training set\n",
    "BUFFER_SIZE = len(input_tensor_train)\n",
    "# NOTE(review): settings_dict also carries a 'batch_size' entry that is not used here - confirm\n",
    "BATCH_SIZE = 64\n",
    "# Number of full batches in one pass over the training set\n",
    "steps_per_epoch = len(input_tensor_train)//BATCH_SIZE\n",
    "embedding_dim = 256\n",
    "units = 1024\n",
    "# +1 on top of the tokenizer vocabulary - presumably to leave room for id 0, which the\n",
    "# padded sequences use (see `convert` / `loss_function`); verify against the Keras docs\n",
    "vocab_inp_size = len(inp_lang_tokenizer.word_index)+1\n",
    "vocab_tar_size = len(targ_lang_tokenizer.word_index)+1\n",
    "\n",
    "dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train)).shuffle(BUFFER_SIZE)\n",
    "# drop_remainder=True discards the final partial batch so every batch has BATCH_SIZE rows\n",
    "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(TensorShape([64, 99]), TensorShape([64, 6]))"
      ]
     },
     "execution_count": 87,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "example_input_batch, example_target_batch = next(iter(dataset))\n",
    "example_input_batch.shape, example_target_batch.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Encoder(tf.keras.Model):\n",
    "    \"\"\"Embeds token ids and runs them through a single GRU layer.\"\"\"\n",
    "    def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):\n",
    "        \"\"\"Build the layers.\n",
    "\n",
    "        Args:\n",
    "            vocab_size: size of the embedding's input vocabulary.\n",
    "            embedding_dim: width of the embedding vectors.\n",
    "            enc_units: number of GRU units.\n",
    "            batch_sz: batch size used by ``initialize_hidden_state``.\n",
    "        \"\"\"\n",
    "        super(Encoder, self).__init__()\n",
    "        self.batch_sz = batch_sz\n",
    "        self.enc_units = enc_units\n",
    "        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n",
    "        # return_sequences + return_state: the GRU yields both the per-step outputs and its last state\n",
    "        self.gru = tf.keras.layers.GRU(self.enc_units,\n",
    "                                       return_sequences=True,\n",
    "                                       return_state=True,\n",
    "                                       recurrent_initializer='glorot_uniform')\n",
    "\n",
    "    def call(self, x, hidden):\n",
    "        \"\"\"Embed ``x`` and run the GRU from initial state ``hidden``; returns ``(output, state)``.\"\"\"\n",
    "        x = self.embedding(x)\n",
    "        output, state = self.gru(x, initial_state = hidden)\n",
    "        return output, state\n",
    "\n",
    "    def initialize_hidden_state(self):\n",
    "        \"\"\"Return an all-zero tensor of shape ``(batch_sz, enc_units)``.\"\"\"\n",
    "        return tf.zeros((self.batch_sz, self.enc_units))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Encoder output shape: (batch size, sequence length, units) (64, 99, 1024)\n",
      "Encoder Hidden state shape: (batch size, units) (64, 1024)\n"
     ]
    }
   ],
   "source": [
    "encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)\n",
    "\n",
    "# sample input\n",
    "sample_hidden = encoder.initialize_hidden_state()\n",
    "sample_output, sample_hidden = encoder(example_input_batch, sample_hidden)\n",
    "print ('Encoder output shape: (batch size, sequence length, units) {}'.format(sample_output.shape))\n",
    "print ('Encoder Hidden state shape: (batch size, units) {}'.format(sample_hidden.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BahdanauAttention(tf.keras.Model):\n",
    "    \"\"\"Additive attention: scores ``values`` against ``query`` and returns their weighted sum.\"\"\"\n",
    "    def __init__(self, units):\n",
    "        \"\"\"Create two ``units``-wide dense projections (W1, W2) and a 1-unit scoring layer (V).\"\"\"\n",
    "        super(BahdanauAttention, self).__init__()\n",
    "        self.W1 = tf.keras.layers.Dense(units)\n",
    "        self.W2 = tf.keras.layers.Dense(units)\n",
    "        self.V = tf.keras.layers.Dense(1)\n",
    "\n",
    "    def call(self, query, values):\n",
    "        \"\"\"Return ``(context_vector, attention_weights)`` for ``query`` attending over ``values``.\"\"\"\n",
    "        # hidden shape == (batch_size, hidden size)\n",
    "        # hidden_with_time_axis shape == (batch_size, 1, hidden size)\n",
    "        # we are doing this to perform addition to calculate the score\n",
    "        hidden_with_time_axis = tf.expand_dims(query, 1)\n",
    "\n",
    "        # score shape == (batch_size, max_length, 1)\n",
    "        # we get 1 at the last axis because we are applying score to self.V\n",
    "        # the shape of the tensor before applying self.V is (batch_size, max_length, units)\n",
    "        score = self.V(tf.nn.tanh(\n",
    "            self.W1(values) + self.W2(hidden_with_time_axis)))\n",
    "\n",
    "        # attention_weights shape == (batch_size, max_length, 1)\n",
    "        # softmax over axis 1 normalises the scores across the sequence positions\n",
    "        attention_weights = tf.nn.softmax(score, axis=1)\n",
    "\n",
    "        # context_vector shape after sum == (batch_size, hidden_size)\n",
    "        context_vector = attention_weights * values\n",
    "        context_vector = tf.reduce_sum(context_vector, axis=1)\n",
    "\n",
    "        return context_vector, attention_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Attention result shape: (batch size, units) (64, 1024)\n",
      "Attention weights shape: (batch_size, sequence_length, 1) (64, 99, 1)\n"
     ]
    }
   ],
   "source": [
    "attention_layer = BahdanauAttention(10)\n",
    "attention_result, attention_weights = attention_layer(sample_hidden, sample_output)\n",
    "\n",
    "print(\"Attention result shape: (batch size, units) {}\".format(attention_result.shape))\n",
    "print(\"Attention weights shape: (batch_size, sequence_length, 1) {}\".format(attention_weights.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Decoder(tf.keras.Model):\n",
    "    \"\"\"Single-step decoder: attention over the encoder output, then embedding, GRU and a dense layer.\"\"\"\n",
    "    def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):\n",
    "        \"\"\"Build the layers.\n",
    "\n",
    "        Args:\n",
    "            vocab_size: size of the embedding vocabulary and of the output layer.\n",
    "            embedding_dim: width of the embedding vectors.\n",
    "            dec_units: number of GRU units; also the width of the attention projections.\n",
    "            batch_sz: stored on the instance; not read elsewhere in this class.\n",
    "        \"\"\"\n",
    "        super(Decoder, self).__init__()\n",
    "        self.batch_sz = batch_sz\n",
    "        self.dec_units = dec_units\n",
    "        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n",
    "        self.gru = tf.keras.layers.GRU(self.dec_units,\n",
    "                                       return_sequences=True,\n",
    "                                       return_state=True,\n",
    "                                       recurrent_initializer='glorot_uniform')\n",
    "        # no activation here: the layer outputs raw scores (the loss is built with from_logits=True)\n",
    "        self.fc = tf.keras.layers.Dense(vocab_size)\n",
    "\n",
    "        # used for attention\n",
    "        self.attention = BahdanauAttention(self.dec_units)\n",
    "\n",
    "    def call(self, x, hidden, enc_output):\n",
    "        \"\"\"Decode one step; returns ``(scores, state, attention_weights)``.\"\"\"\n",
    "        # enc_output shape == (batch_size, max_length, hidden_size)\n",
    "        context_vector, attention_weights = self.attention(hidden, enc_output)\n",
    "\n",
    "        # x shape after passing through embedding == (batch_size, 1, embedding_dim)\n",
    "        x = self.embedding(x)\n",
    "\n",
    "        # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)\n",
    "        x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n",
    "\n",
    "        # passing the concatenated vector to the GRU\n",
    "        # NOTE(review): `hidden` is only used as the attention query; the GRU is called\n",
    "        # without initial_state, so it does not start from `hidden` - confirm this is intended\n",
    "        output, state = self.gru(x)\n",
    "\n",
    "        # output shape == (batch_size * 1, hidden_size)\n",
    "        output = tf.reshape(output, (-1, output.shape[2]))\n",
    "\n",
    "        # output shape == (batch_size, vocab)\n",
    "        x = self.fc(output)\n",
    "\n",
    "        return x, state, attention_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Decoder output shape: (batch_size, vocab size) (64, 14)\n"
     ]
    }
   ],
   "source": [
    "decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)\n",
    "\n",
    "# Dummy decoder input: one token position per batch row. Use BATCH_SIZE rather than a\n",
    "# literal 64 so it stays consistent with sample_hidden / sample_output.\n",
    "sample_decoder_output, _, _ = decoder(tf.random.uniform((BATCH_SIZE, 1)),\n",
    "                                      sample_hidden, sample_output)\n",
    "\n",
    "print ('Decoder output shape: (batch_size, vocab size) {}'.format(sample_decoder_output.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = tf.keras.optimizers.Adam()\n",
    "# reduction='none' keeps one loss value per sample so padding can be masked out below\n",
    "loss_object = tf.keras.losses.SparseCategoricalCrossentropy(\n",
    "    from_logits=True, reduction='none')\n",
    "\n",
    "def loss_function(real, pred):\n",
    "    \"\"\"Mean cross-entropy over the batch, with targets equal to 0 contributing zero loss.\"\"\"\n",
    "    per_sample_loss = loss_object(real, pred)\n",
    "\n",
    "    # 1 where the target is non-zero, 0 where it equals 0\n",
    "    is_zero_target = tf.math.equal(real, 0)\n",
    "    keep = tf.cast(tf.math.logical_not(is_zero_target), dtype=per_sample_loss.dtype)\n",
    "\n",
    "    return tf.reduce_mean(per_sample_loss * keep)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def train_step(inp, targ, enc_hidden):\n",
    "    \"\"\"Run one teacher-forced training step on a batch and apply the gradients.\n",
    "\n",
    "    Uses the notebook-level ``encoder``, ``decoder``, ``optimizer``,\n",
    "    ``targ_lang_tokenizer`` and ``BATCH_SIZE``.\n",
    "\n",
    "    Args:\n",
    "        inp: batch of encoder input ids.\n",
    "        targ: batch of target ids; only columns 1 onwards are scored.\n",
    "        enc_hidden: initial encoder state.\n",
    "\n",
    "    Returns:\n",
    "        The summed per-step loss divided by ``targ.shape[1]``.\n",
    "    \"\"\"\n",
    "    loss = 0\n",
    "\n",
    "    with tf.GradientTape() as tape:\n",
    "        enc_output, enc_hidden = encoder(inp, enc_hidden)\n",
    "\n",
    "        # the decoder starts from the encoder's final state\n",
    "        dec_hidden = enc_hidden\n",
    "\n",
    "        # first decoder input: the id of the tab character, repeated for every row of the batch\n",
    "        dec_input = tf.expand_dims([targ_lang_tokenizer.word_index['\\t']] * BATCH_SIZE, 1)\n",
    "\n",
    "        # Teacher forcing - feeding the target as the next input\n",
    "        for t in range(1, targ.shape[1]):\n",
    "            # passing enc_output to the decoder\n",
    "            predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)\n",
    "\n",
    "            loss += loss_function(targ[:, t], predictions)\n",
    "\n",
    "            # using teacher forcing\n",
    "            dec_input = tf.expand_dims(targ[:, t], 1)\n",
    "\n",
    "    # NOTE(review): the loop above adds targ.shape[1] - 1 terms but the divisor is\n",
    "    # targ.shape[1] - confirm this normalisation is intended\n",
    "    batch_loss = (loss / int(targ.shape[1]))\n",
    "\n",
    "    variables = encoder.trainable_variables + decoder.trainable_variables\n",
    "\n",
    "    # gradients are taken of the summed loss, not of the normalised batch_loss\n",
    "    gradients = tape.gradient(loss, variables)\n",
    "\n",
    "    optimizer.apply_gradients(zip(gradients, variables))\n",
    "\n",
    "    return batch_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1 Batch 0 Loss 1.0914\n",
      "Epoch 1 Loss 0.0001\n",
      "Time taken for 1 epoch 19.425013065338135 sec\n",
      "\n",
      "Epoch 1 Loss 0.0003\n",
      "Time taken for 1 epoch 21.822988033294678 sec\n",
      "\n",
      "Epoch 1 Loss 0.0004\n",
      "Time taken for 1 epoch 24.384517192840576 sec\n",
      "\n",
      "Epoch 1 Loss 0.0007\n",
      "Time taken for 1 epoch 26.82024908065796 sec\n",
      "\n",
      "Epoch 1 Loss 0.0007\n",
      "Time taken for 1 epoch 29.324955224990845 sec\n",
      "\n",
      "Epoch 1 Loss 0.0009\n",
      "Time taken for 1 epoch 31.91022515296936 sec\n",
      "\n",
      "Epoch 1 Loss 0.0010\n",
      "Time taken for 1 epoch 34.61877512931824 sec\n",
      "\n",
      "Epoch 1 Loss 0.0011\n",
      "Time taken for 1 epoch 37.23060607910156 sec\n",
      "\n",
      "Epoch 1 Loss 0.0012\n",
      "Time taken for 1 epoch 39.89896106719971 sec\n",
      "\n",
      "Epoch 1 Loss 0.0013\n",
      "Time taken for 1 epoch 42.680989265441895 sec\n",
      "\n",
      "Epoch 1 Loss 0.0014\n",
      "Time taken for 1 epoch 45.590203046798706 sec\n",
      "\n",
      "Epoch 1 Loss 0.0014\n",
      "Time taken for 1 epoch 48.44242000579834 sec\n",
      "\n",
      "Epoch 1 Loss 0.0015\n",
      "Time taken for 1 epoch 51.15497016906738 sec\n",
      "\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m----------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                    Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-97-8a71a42a25b0>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      9\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0minp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtake\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msteps_per_epoch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m         \u001b[0mbatch_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0menc_hidden\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     11\u001b[0m         \u001b[0mtotal_loss\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mbatch_loss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/git/deep-math/env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwds)\u001b[0m\n\u001b[1;32m    402\u001b[0m       \u001b[0;31m# In this case we have created variables on the first call, so we run the\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    403\u001b[0m       \u001b[0;31m# defunned version which is guaranteed to never create variables.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 404\u001b[0;31m       \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_stateless_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# pylint: disable=not-callable\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    405\u001b[0m     \u001b[0;32melif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_stateful_fn\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    406\u001b[0m       \u001b[0;31m# In this case we have not created variables on the first call. So we can\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/git/deep-math/env/lib/python3.7/site-packages/tensorflow/python/eager/function.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1333\u001b[0m     \u001b[0;34m\"\"\"Calls a graph function specialized to the inputs.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1334\u001b[0m     \u001b[0mgraph_function\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_maybe_define_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1335\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0mgraph_function\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_filtered_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# pylint: disable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1336\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1337\u001b[0m   \u001b[0;34m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/git/deep-math/env/lib/python3.7/site-packages/tensorflow/python/eager/function.py\u001b[0m in \u001b[0;36m_filtered_call\u001b[0;34m(self, args, kwargs)\u001b[0m\n\u001b[1;32m    587\u001b[0m     \"\"\"\n\u001b[1;32m    588\u001b[0m     return self._call_flat(\n\u001b[0;32m--> 589\u001b[0;31m         (t for t in nest.flatten((args, kwargs), expand_composites=True)\n\u001b[0m\u001b[1;32m    590\u001b[0m          if isinstance(t, (ops.Tensor,\n\u001b[1;32m    591\u001b[0m                            resource_variable_ops.ResourceVariable))))\n",
      "\u001b[0;32m~/git/deep-math/env/lib/python3.7/site-packages/tensorflow/python/eager/function.py\u001b[0m in \u001b[0;36m_call_flat\u001b[0;34m(self, args)\u001b[0m\n\u001b[1;32m    669\u001b[0m     \u001b[0;31m# Only need to override the gradient in graph mode and when we have outputs.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    670\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mcontext\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexecuting_eagerly\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 671\u001b[0;31m       \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_inference_function\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mctx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    672\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    673\u001b[0m       \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_register_gradient\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/git/deep-math/env/lib/python3.7/site-packages/tensorflow/python/eager/function.py\u001b[0m in \u001b[0;36mcall\u001b[0;34m(self, ctx, args)\u001b[0m\n\u001b[1;32m    443\u001b[0m             attrs=(\"executor_type\", executor_type,\n\u001b[1;32m    444\u001b[0m                    \"config_proto\", config),\n\u001b[0;32m--> 445\u001b[0;31m             ctx=ctx)\n\u001b[0m\u001b[1;32m    446\u001b[0m       \u001b[0;31m# Replace empty list with None\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    447\u001b[0m       \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/git/deep-math/env/lib/python3.7/site-packages/tensorflow/python/eager/execute.py\u001b[0m in \u001b[0;36mquick_execute\u001b[0;34m(op_name, num_outputs, inputs, attrs, ctx, name)\u001b[0m\n\u001b[1;32m     59\u001b[0m     tensors = pywrap_tensorflow.TFE_Py_Execute(ctx._handle, device_name,\n\u001b[1;32m     60\u001b[0m                                                \u001b[0mop_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattrs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 61\u001b[0;31m                                                num_outputs)\n\u001b[0m\u001b[1;32m     62\u001b[0m   \u001b[0;32mexcept\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_NotOkStatusException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     63\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "EPOCHS = 1\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "    start = time.time()\n",
    "\n",
    "    # fresh all-zero encoder state at the start of every epoch\n",
    "    enc_hidden = encoder.initialize_hidden_state()\n",
    "    total_loss = 0\n",
    "\n",
    "    for (batch, (inp, targ)) in enumerate(dataset.take(steps_per_epoch)):\n",
    "        batch_loss = train_step(inp, targ, enc_hidden)\n",
    "        total_loss += batch_loss\n",
    "\n",
    "        # progress line every 100 batches\n",
    "        if batch % 100 == 0:\n",
    "            print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,\n",
    "                                                     batch,\n",
    "                                                     batch_loss.numpy()))\n",
    "\n",
    "    # Epoch summary: printed once per epoch, after the batch loop. These lines must stay\n",
    "    # at epoch level - inside the batch loop they would print a partial average and the\n",
    "    # elapsed time after every single batch.\n",
    "    print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n",
    "                                  total_loss / steps_per_epoch))\n",
    "    print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
